import os
import sys
import math
import torch
import os.path as osp
import torchvision.utils as tvu

sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2]))
from tqdm import tqdm
from diffusion_tools.diffusion.diffusion import GaussianDiffusion

# Names exported by `from <this module> import *`. GaussianDiffusion is re-exported
# from diffusion_tools.diffusion.diffusion (imported above).
# NOTE(review): CompressedTDiffusion and compression_schedule, both defined in this
# file, are not listed here — confirm whether that omission is intended.
__all__ = ['GaussianDiffusion', 'beta_schedule']

def kl_divergence(mu1, logvar1, mu2, logvar2):
    r"""Element-wise KL( N(mu1, exp(logvar1)) || N(mu2, exp(logvar2)) ).

    Both Gaussians are parameterised by mean and log-variance; inputs are
    combined with ordinary tensor broadcasting.
    """
    # Summation order matches the closed form: -1 + log(v2/v1) + v1/v2 + (mu1-mu2)^2/v2
    base = -1.0 + logvar2 - logvar1
    variance_ratio = torch.exp(logvar1 - logvar2)
    mean_term = ((mu1 - mu2) ** 2) * torch.exp(-logvar2)
    return 0.5 * (base + variance_ratio + mean_term)

def standard_normal_cdf(x):
    r"""Approximate the standard normal CDF Phi(x) with a tanh-based closed form.

    Uses 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), which is cheap to
    evaluate element-wise on tensors.
    """
    scale = math.sqrt(2.0 / math.pi)
    cubic_corrected = x + 0.044715 * torch.pow(x, 3)
    return 0.5 * (1.0 + torch.tanh(scale * cubic_corrected))

def discretized_gaussian_log_likelihood(x0, mean, log_scale):
    r"""Log-likelihood of ``x0`` under a Gaussian integrated over a bin of half-width 1/255.

    Values below -0.999 use the lower tail (CDF at the upper bin edge), values
    above 0.999 use the upper tail (1 - CDF at the lower bin edge), and all other
    values use the probability mass between the two edges. Presumably ``x0`` is
    8-bit data rescaled to [-1, 1] — confirm against callers.

    Args:
        x0: observed values; must have the same shape as ``mean``.
        mean: Gaussian means.
        log_scale: log standard deviations (broadcast against ``x0``).

    Returns:
        Tensor of log-probabilities with the same shape as ``x0``.
    """
    assert x0.shape == mean.shape
    half_bin = 1.0 / 255.0
    floor_eps = 1e-12  # keeps log() finite when a CDF value underflows

    centered = x0 - mean
    inv_std = torch.exp(-log_scale)
    upper_cdf = standard_normal_cdf(inv_std * (centered + half_bin))
    lower_cdf = standard_normal_cdf(inv_std * (centered - half_bin))

    log_left_tail = torch.log(upper_cdf.clamp(min=floor_eps))
    log_right_tail = torch.log((1.0 - lower_cdf).clamp(min=floor_eps))
    log_bin_mass = torch.log((upper_cdf - lower_cdf).clamp(min=floor_eps))

    interior_or_right = torch.where(x0 > 0.999, log_right_tail, log_bin_mass)
    log_probs = torch.where(x0 < -0.999, log_left_tail, interior_or_right)
    assert log_probs.shape == x0.shape
    return log_probs

def _i(tensor, t, x):
    r"""Index tensor using t and format the output according to x.
    """
    shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
    return tensor[t.cpu()].view(shape).to(x)

def beta_schedule(schedule, num_timesteps=1000, init_beta=None, last_beta=None):
    r"""Build a diffusion noise schedule (per-step betas) as a float64 tensor.

    Args:
        schedule: one of 'linear', 'quadratic', 'cosine', 'cosine_shift',
            'linear_zero', 'quadratic_zero', 'cosine_zero', 'cosine_shift_zero'.
        num_timesteps: number of betas to produce.
        init_beta: first beta for 'linear' / 'quadratic'. ``None`` selects the
            default; any other value (including 0.0) is used as given.
        last_beta: last beta for 'linear', 'quadratic', 'linear_zero' and
            'quadratic_zero'. ``None`` selects the default; 0.0 is honoured.
            Both bounds are ignored by the cosine variants.

    Returns:
        1-D ``torch.float64`` tensor of length ``num_timesteps``.

    Raises:
        ValueError: if ``schedule`` is not supported.
    """
    def _or_default(value, default):
        # Only None means "use the default"; `value or default` would also
        # discard an explicit 0.0.
        return default if value is None else value

    def _from_alpha_bar(alpha_bar):
        # beta for step k is 1 - alpha_bar(t2) / alpha_bar(t1) over [k/T, (k+1)/T],
        # capped at 0.999 so no step destroys the signal entirely.
        betas = []
        for step in range(num_timesteps):
            t1 = step / num_timesteps
            t2 = (step + 1) / num_timesteps
            betas.append(min(1.0 - alpha_bar(t2) / alpha_bar(t1), 0.999))
        return torch.tensor(betas, dtype=torch.float64)

    def _shifted(gn):
        # alpha_bar(u) = sigmoid(log-SNR(u)), where log-SNR = -2*log(gn(u)) is
        # shifted by 2*log(1/4). Evaluated with default-dtype torch scalars.
        def alpha_bar(u):
            log_snr = -2 * torch.log(torch.tensor(gn(u))) + 2 * torch.log(torch.tensor(1/4))
            return torch.sigmoid(log_snr)
        return alpha_bar

    if schedule == 'linear':
        scale = 1000.0 / num_timesteps
        init_beta = _or_default(init_beta, scale * 0.0001)
        last_beta = _or_default(last_beta, scale * 0.02)
        return torch.linspace(init_beta, last_beta, num_timesteps, dtype=torch.float64)
    elif schedule == 'quadratic':
        init_beta = _or_default(init_beta, 0.0015)
        last_beta = _or_default(last_beta, 0.0195)
        return torch.linspace(init_beta ** 0.5, last_beta ** 0.5, num_timesteps, dtype=torch.float64) ** 2
    elif schedule == 'cosine':
        return _from_alpha_bar(lambda u: math.cos((u + 0.008) / 1.008 * math.pi / 2) ** 2)
    elif schedule == "cosine_shift":
        return _from_alpha_bar(_shifted(lambda u: math.tan((u + 0.008) / 1.008 * math.pi / 2)))
    elif schedule == 'linear_zero':
        scale = 1000.0 / num_timesteps
        last_beta = _or_default(last_beta, scale * 0.02)
        # Drop the leading 0 so the first beta is strictly positive.
        return torch.linspace(0., last_beta, num_timesteps + 1, dtype=torch.float64)[1:]
    elif schedule == 'quadratic_zero':
        last_beta = _or_default(last_beta, 0.0195)
        return torch.linspace(0., last_beta ** 0.5, num_timesteps + 1, dtype=torch.float64)[1:] ** 2
    elif schedule == 'cosine_zero':
        return _from_alpha_bar(lambda u: math.cos(u * math.pi / 2) ** 2)
    elif schedule == "cosine_shift_zero":
        return _from_alpha_bar(_shifted(lambda u: math.tan(u * math.pi / 2)))
    raise ValueError(f'Unsupported schedule: {schedule}')

def compression_schedule(schedule, num_timesteps=1000, n_compressed_timeranges: int = None):
    r"""Split ``[0, num_timesteps]`` into ranges and pick a conditioning timestep per range.

    ``schedule`` has the form ``'<spacing>_..._<anchor>'``:
      - spacing 'linear': evenly spaced range end-steps.
      - spacing 'quad': end-steps grow quadratically (denser near 0); with exactly
        50 ranges whose first end-step floors to 0, the first four end-steps are
        replaced by 2, 4, 6, 8.
      - spacing 'quadr': mirrored quadratic spacing (denser near ``num_timesteps``).
      - anchor 'start': each range is conditioned on its starting step, i.e. the
        previous range's end-step (0 for the first range).

    Args:
        schedule: schedule name as described above.
        num_timesteps: length of the timestep axis.
        n_compressed_timeranges: number of ranges; falsy values mean
            ``num_timesteps``. Must be <= 50 for 'quad' / 'quadr'.

    Returns:
        ``(compressed_timerange_endsteps, cond_timesteps)``, two ``torch.long``
        tensors of length ``n_compressed_timeranges``.

    Raises:
        NotImplementedError: for an unknown spacing or anchor.
    """
    n_ranges = n_compressed_timeranges or num_timesteps
    tokens = schedule.split('_')
    spacing, anchor = tokens[0], tokens[-1]

    if spacing == 'linear':
        edges = torch.linspace(0, num_timesteps, n_ranges + 1)[1:].to(dtype=torch.long)
    elif spacing == 'quad':
        assert n_ranges <= 50
        coef = num_timesteps / n_ranges ** 2
        raw = [math.floor(coef * k ** 2) for k in range(n_ranges + 1)]
        edges = torch.tensor(raw, dtype=torch.long)[1:]
        # Special case: avoid a zero-length leading range for the 50-range layout.
        if n_ranges == 50 and edges[0] == 0:
            edges[:4] = torch.tensor([2, 4, 6, 8], dtype=torch.long)
    elif spacing == 'quadr':
        assert n_ranges <= 50
        coef = num_timesteps / n_ranges ** 2
        raw = [math.floor(num_timesteps - coef * k ** 2) for k in range(n_ranges, -1, -1)]
        edges = torch.tensor(raw, dtype=torch.long)[1:]
    else:
        raise NotImplementedError()

    if anchor != 'start':
        raise NotImplementedError()
    cond = torch.cat([edges.new_zeros([1]), edges[:-1]])
    return edges, cond

class CompressedTDiffusion(GaussianDiffusion):
    r"""Gaussian diffusion whose network is queried at "compressed" timesteps.

    The timestep axis is split into ranges by ``compressed_timerange_endsteps``
    (for example as produced by ``compression_schedule``). A timestep ``t`` that
    lies inside one of the inclusive windows ``[start_condt[k], end_condt[k]]``
    is fed to the network as ``cond_timesteps[r]``, where ``r`` is the number of
    end-steps ``<= t``. Outside every window the network sees ``t`` unchanged.
    All diffusion coefficients (posterior variances, means, x0 reconstruction)
    are still indexed by the true ``t``; only the network's timestep input changes.
    """

    def __init__(self,
                 betas,
                 compressed_timerange_endsteps,
                 cond_timesteps,
                 start_condt=-1,
                 end_condt=1000,
                 mean_type='eps',
                 var_type='learned_range',
                 loss_type='mse',
                 rescale_timesteps=False):
        r"""
        Args:
            betas: noise schedule forwarded to ``GaussianDiffusion``.
            compressed_timerange_endsteps: tensor of range end-steps (compared
                against ``t`` with ``>=``).
            cond_timesteps: tensor with one conditioning timestep per range.
            start_condt: inclusive lower bound(s) of the conditioning windows;
                a single number or a sequence.
            end_condt: inclusive upper bound(s); same length as ``start_condt``.
                With the defaults (-1, 1000) a single window covers every ``t``.
            mean_type, var_type, loss_type, rescale_timesteps: forwarded to
                ``GaussianDiffusion``.

        Raises:
            ValueError: if the bound lists differ in length or any window has
                ``start > end``.
        """
        super().__init__(betas, mean_type, var_type, loss_type, rescale_timesteps)
        self.compressed_timerange_endsteps = compressed_timerange_endsteps
        self.cond_timesteps = cond_timesteps
        self.n_compressed_timeranges = len(self.compressed_timerange_endsteps)
        # Normalise scalar bounds (the defaults) to one-element lists so that
        # len()/indexing work uniformly.
        self.start_condt = self._as_bound_list(start_condt)
        self.end_condt = self._as_bound_list(end_condt)
        if len(self.start_condt) != len(self.end_condt):
            raise ValueError('start_condt and end_condt must have the same length')
        if any(lo > hi for lo, hi in zip(self.start_condt, self.end_condt)):
            raise ValueError('every start_condt must be <= the matching end_condt')
        self.tlist_length = len(self.start_condt)

    @staticmethod
    def _as_bound_list(bounds):
        r"""Return ``bounds`` as a Python list; a scalar becomes a one-element list."""
        if isinstance(bounds, torch.Tensor):
            return bounds.flatten().tolist()
        if isinstance(bounds, (list, tuple)):
            return list(bounds)
        return [bounds]

    def _conditioned_timesteps(self, t):
        r"""Map each entry of the 1-D timestep tensor ``t`` to the network's timestep.

        Entries inside any ``[start_condt[k], end_condt[k]]`` window are replaced
        by ``cond_timesteps[r]`` (``r`` = number of end-steps ``<= t``); all other
        entries are returned unchanged.
        """
        endsteps = self.compressed_timerange_endsteps.to(t.device)
        range_idx = torch.sum(t.view(-1, 1) >= endsteps.view(1, -1), dim=1)
        # A t at/after the last end-step would index one past the end; clamp so the
        # lookup stays in bounds (out-of-window entries are discarded below anyway).
        range_idx = range_idx.clamp(max=len(endsteps) - 1)
        t_cond = self.cond_timesteps.to(t.device)[range_idx]
        in_window = torch.zeros_like(t, dtype=torch.bool)
        for lo, hi in zip(self.start_condt, self.end_condt):
            in_window = in_window | ((t >= lo) & (t <= hi))
        return torch.where(in_window, t_cond, t)

    def p_mean_variance(self, xt, t, model, model_kwargs=None, clamp=None, percentile=None, guide_scale=None):
        r"""Distribution of p(x_{t-1} | x_t).

        The network is called with the conditioned timestep (see
        ``_conditioned_timesteps``); all coefficients use the true ``t``.

        Args:
            xt: noisy sample at step ``t``.
            t: per-sample timesteps.
            model: network called as ``model(xt, scaled_t, **kwargs)``.
            model_kwargs: keyword arguments for ``model``; with ``guide_scale``
                it must be a 2-element list ``[cond_kwargs, uncond_kwargs]``.
            clamp: optional symmetric clamp applied to the predicted x0.
            percentile: optional dynamic-thresholding quantile in (0, 1].
            guide_scale: classifier-free guidance scale, or ``None``.

        Returns:
            ``(mu, var, log_var, x0)``.

        Raises:
            ValueError: for an unsupported ``var_type`` or ``mean_type``.
        """
        model_kwargs = {} if model_kwargs is None else model_kwargs

        # predict distribution
        scaled_t = self._scale_timesteps(self._conditioned_timesteps(t))
        if guide_scale is None:
            out = model(xt, scaled_t, **model_kwargs)
        else:
            assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
            y_out = model(xt, scaled_t, **model_kwargs[0])
            u_out = model(xt, scaled_t, **model_kwargs[1])
            # only the mean channels are guided; learned-variance channels come from y_out
            dim = y_out.size(1) if self.var_type.startswith('fixed') else y_out.size(1) // 2
            out = torch.cat([
                u_out[:, :dim] + guide_scale * (y_out[:, :dim] - u_out[:, :dim]),
                y_out[:, dim:]], dim=1)

        # compute variance
        if self.var_type == 'learned':
            out, log_var = out.chunk(2, dim=1)
            var = torch.exp(log_var)
        elif self.var_type == 'learned_range':
            out, fraction = out.chunk(2, dim=1)
            min_log_var = _i(self.posterior_log_variance_clipped, t, xt)
            max_log_var = _i(torch.log(self.betas), t, xt)
            fraction = (fraction + 1) / 2.0  # map [-1, 1] output to an interpolation weight
            log_var = fraction * max_log_var + (1 - fraction) * min_log_var
            var = torch.exp(log_var)
        elif self.var_type == 'fixed_large':
            var = _i(torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t, xt)
            log_var = torch.log(var)
        elif self.var_type == 'fixed_small':
            var = _i(self.posterior_variance, t, xt)
            log_var = _i(self.posterior_log_variance_clipped, t, xt)
        else:
            raise ValueError(f'Unsupported var_type: {self.var_type}')

        # compute mean and x0
        if self.mean_type == 'x_{t-1}':
            mu = out  # x_{t-1}
            x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - \
                 _i(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, xt) * xt
        elif self.mean_type == 'x0':
            x0 = out
            mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
        elif self.mean_type == 'eps':
            x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
                 _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out
            if clamp is not None:
                x0 = x0.clamp(-clamp, clamp)
            mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
        elif self.mean_type == 'v':
            x0 = _i(self.sqrt_alphas_cumprod, t, xt) * xt - \
                 _i(self.sqrt_one_minus_alphas_cumprod, t, xt) * out
            if clamp is not None:
                x0 = x0.clamp(-clamp, clamp)
            mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
        else:
            raise ValueError(f'Unsupported mean_type: {self.mean_type}')

        # restrict the range of x0
        if percentile is not None:
            assert percentile > 0 and percentile <= 1  # e.g., 0.995
            s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1).clamp_(1.0)
            s = s.view(-1, *([1] * (x0.ndim - 1)))  # broadcast per-sample threshold
            x0 = torch.min(s, torch.max(-s, x0)) / s
        elif clamp is not None:
            x0 = x0.clamp(-clamp, clamp)
        return mu, var, log_var, x0

    def _loss_target(self, x0, xt, t, noise):
        r"""Regression target for ``self.mean_type``; raises KeyError for unknown types."""
        if self.mean_type == 'eps':
            return noise
        if self.mean_type == 'x0':
            return x0
        if self.mean_type == 'x_{t-1}':
            return self.q_posterior_mean_variance(x0, xt, t)[0]
        if self.mean_type == 'v':
            return self.v_prediction_groundtruth(noise, x0, t)
        raise KeyError(self.mean_type)

    def loss(self, x0, t, model, model_kwargs=None, noise=None):
        r"""Training loss at timesteps ``t``.

        For the MSE/L1 family the network sees the conditioned timestep; noising
        and the VLB term use the true ``t``. Returns a per-sample loss for the
        KL and MSE/L1 families and a scalar for 'mse_sqrt'.

        Raises:
            ValueError: for an unsupported ``loss_type``.
        """
        model_kwargs = {} if model_kwargs is None else model_kwargs
        noise = torch.randn_like(x0) if noise is None else noise
        xt = self.q_sample(x0, t, noise=noise)

        # compute loss
        if self.loss_type in ['kl', 'rescaled_kl']:
            loss, _ = self.variational_lower_bound(x0, xt, t, model, model_kwargs)
            if self.loss_type == 'rescaled_kl':
                loss = loss * self.num_timesteps
        elif self.loss_type in ['mse', 'rescaled_mse', 'l1', 'rescaled_l1', 'mse_mean']:
            t_cond = self._conditioned_timesteps(t)
            out = model(xt, self._scale_timesteps(t_cond), **model_kwargs)

            # VLB for variation
            loss_vlb = 0.0
            if self.var_type in ['learned', 'learned_range']:
                out, var = out.chunk(2, dim=1)
                frozen = torch.cat([out.detach(), var], dim=1)  # learn var without affecting the prediction of mean
                loss_vlb, _ = self.variational_lower_bound(x0, xt, t, model=lambda *args, **kwargs: frozen)
                if self.loss_type.startswith('rescaled_'):
                    loss_vlb = loss_vlb * self.num_timesteps / 1000.0

            # MSE/L1 for x0/eps
            target = self._loss_target(x0, xt, t, noise)
            loss = (out - target).pow(1 if self.loss_type.endswith('l1') else 2).abs().flatten(1).mean(dim=1)
            if self.loss_type == 'mse_mean':
                # extra penalty on batch-level mean/variance mismatch
                mean_diff = (out.mean() - target.mean()).pow(2)
                var_target = (target - target.mean()).pow(2).mean()
                out_target = (out - out.mean()).pow(2).mean()
                var_diff = (var_target - out_target).pow(2)
                loss = loss + 10 * (mean_diff + var_diff)

            # total loss
            loss = loss + loss_vlb
        elif self.loss_type in ['mse_sqrt']:
            # NOTE(review): this branch feeds the raw ``t`` (not the conditioned
            # timestep) to the network — confirm that is intended.
            out = model(xt, self._scale_timesteps(t), **model_kwargs)
            target = self._loss_target(x0, xt, t, noise)
            loss = torch.nn.functional.mse_loss(target, out)
        else:
            raise ValueError(f'Unsupported loss_type: {self.loss_type}')
        return loss

    @torch.no_grad()
    def ddim_sample(self, xt, t, t_pre, model, model_kwargs=None, clamp=None, percentile=None, condition_fn=None, guide_scale=None, ddim_timesteps=20, eta=0.0):
        r"""Sample from p(x_{t_pre} | x_t) using DDIM.
            - condition_fn: for classifier-based guidance (guided-diffusion).
            - guide_scale: for classifier-free guidance (glide/dalle-2).
            - ddim_timesteps: unused here; kept for signature compatibility.

        Returns ``(x_{t_pre}, x0)``.
        """
        model_kwargs = {} if model_kwargs is None else model_kwargs
        _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale)
        if condition_fn is not None:
            # x0 -> eps
            alpha = _i(self.alphas_cumprod, t, xt)
            eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
                  _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
            eps = eps - (1 - alpha).sqrt() * condition_fn(xt, self._scale_timesteps(t), **model_kwargs)

            # eps -> x0
            x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
                 _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps

        # derive variables
        eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
              _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
        alphas = _i(self.alphas_cumprod, t, xt)
        alphas_prev = _i(self.alphas_cumprod, t_pre.clamp(0), xt)
        sigmas = eta * torch.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))

        # random sample
        noise = torch.randn_like(xt)
        direction = torch.sqrt(1 - alphas_prev - sigmas ** 2) * eps
        mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))  # no noise at t == 0
        xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise
        return xt_1, x0

    def _ddim_steps(self, scondt, econdt, ddim_timesteps):
        r"""Descending timesteps visited by ``ddim_sample_loop``.

        Concatenates, in order:
          - if ``econdt < 999``: every step from ``num_timesteps - 1`` down to ``econdt + 1``;
          - the strided segment ``1 + arange(scondt, econdt, num_timesteps // ddim_timesteps)``;
          - if ``scondt > 1``: every step from ``scondt`` down to 1.
        All values are clamped to ``[0, num_timesteps - 1]``.
        """
        last = self.num_timesteps - 1
        stride = self.num_timesteps // ddim_timesteps
        parts = []
        if econdt < 999:
            parts.append((1 + torch.arange(econdt, self.num_timesteps, 1)).clamp(0, last).flip(0))
        parts.append((1 + torch.arange(scondt, econdt, stride)).clamp(0, last).flip(0))
        if scondt > 1:
            parts.append((1 + torch.arange(0, scondt, 1)).clamp(0, last).flip(0))
        steps = torch.cat(parts, dim=0)
        # The dense tail ends at num_timesteps, which the clamp folds onto
        # num_timesteps - 1, so that step would appear twice. A DDIM step with
        # t == t_pre returns x_t unchanged (alphas_prev == alphas, sigma == 0) yet
        # still costs a model call and a randn draw, so drop consecutive repeats.
        return torch.unique_consecutive(steps)

    @torch.no_grad()
    def ddim_sample_loop(self, noise, model, model_kwargs=None, clamp=None, scondt=-1, econdt=1000, percentile=None, condition_fn=None, guide_scale=None, ddim_timesteps=20, eta=0.0, progress_desc=None):
        r"""Run DDIM sampling from ``noise`` down to t = 0.

        ``scondt`` / ``econdt`` choose which part of the trajectory is strided
        (see ``_ddim_steps``); each step moves to the next scheduled timestep and
        the final step targets t = 0. Shows a tqdm bar when ``progress_desc`` is set.
        """
        model_kwargs = {} if model_kwargs is None else model_kwargs
        b = noise.size(0)
        xt = noise

        steps = self._ddim_steps(scondt, econdt, ddim_timesteps)
        pre_steps = torch.cat([steps[1:], torch.zeros_like(steps)[:1]], dim=0)
        pairs = zip(pre_steps, steps)
        if progress_desc:
            pairs = tqdm(pairs, desc=progress_desc, total=len(steps))

        for pre_step, step in pairs:
            t_pre = torch.full((b, ), pre_step, dtype=torch.long, device=xt.device)
            t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
            xt, _ = self.ddim_sample(xt, t, t_pre, model, model_kwargs, clamp, percentile, condition_fn, guide_scale, ddim_timesteps, eta)
        return xt